{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Truncated SVD\n",
    "\n",
    "This is a simple application of [Truncated SVD](http://langvillea.people.cofc.edu/DISSECTION-LAB/Emmie'sLSI-SVDModule/p5module.html), just to get a feeling of what happens to the accuracy if we use TruncatedSVD **w/o fine-tuning.**\n",
    "\n",
    "We apply Truncated SVD on the linear layer found at the end of ResNet50, and run a test over the validation dataset to measure the impact on the classification accuracy.\n",
    "\n",
    "Spoiler: \n",
    "\n",
    "At k=800 (80%) we get\n",
    "    - Top1: 76.02 \n",
    "    - Top5: 92.86\n",
    "    \n",
    "    Total weights: 1000 * 800 + 800 * 2048 = 2,438,400 (vs. 1000 * 2048 = 2,048,000)\n",
    "    \n",
    "At k=700 (70%) we get\n",
    "    - Top1: 76.03 \n",
    "    - Top5: 92.85\n",
    "    \n",
    "    Total weights: 1000 * 700 + 700 * 2048 = 2,133,600 (vs. 2,048,000)    \n",
    "\n",
    "At k=600 (60%) we get\n",
    "    - Top1: 75.98 \n",
    "    - Top5: 92.82\n",
    "    \n",
    "    Total weights: 1000 * 600 + 600 * 2048 = 1,828,800 (vs. 2,048,000) \n",
    "    \n",
    "At k=500 (50%) we get\n",
    "    - Top1: 75.78  \n",
    "    - Top5: 92.77\n",
    "    \n",
    "    Total weights: 1000 * 500 + 500 * 2048 = 1,524,000 (vs. 2,048,000) \n",
    "\n",
    "At k=400 (40%) we get\n",
    "    - Top1: 75.65  \n",
    "    - Top5: 92.75\n",
    "    \n",
    "    Total weights: 1000 * 400 + 400 * 2048 = 1,219,200 (vs. 2,048,000)\n",
    "    \n",
    "## Details\n",
    "\n",
    "[SVD (Singular Value Decomposition)](https://en.wikipedia.org/wiki/Singular_value_decomposition) is an exact factorization of a matrix, W (of shape m x n), to the form USV$^T$ (U is m x m, S is m x n, V$^T$ is n x n; V$^T$ is the transpose of V). Every matrix has an SVD.\n",
    "\n",
    "A Linear (fully-connected) layer performs: y = Wx + b   (or y = xW$^T$ + b)<br>\n",
    "We can use SVD to refactor W to rewrite this as: y = (USV$^T$)x + b\n",
    "\n",
    "So far, we haven’t done any compression, so let’s get to it using Truncated SVD.<br> TSVD is a method to provide an approximated decomposition of W, in which S has a lower rank. We want to find an approximation of W that is “good enough” and also accelerates the computation of Wx.\n",
    "\n",
    "We choose some lower-rank, k, such that k<m (preferably k<<m).<br>\n",
    "TSVD is straight-forward: keep the largest k singular values of S and discard the rest (truncate S).\n",
    "\n",
    "After TSVD we have:\n",
    "U’ is m x t, S’ is k x k, V’$^T$ is k x n.<br>\n",
    "y ~ (U’S’V’)x + b<br>\n",
    "y ~ (U’(S’V’))x + b<br>\n",
    "\n",
    "We'll replace S’V’$^T$ with A, because we can pre-compute it once. A has shape k x n:<br>\n",
    "y ~ (U’A)x + b<br>\n",
    "y ~ U’(Ax) + b<br>\n",
    "\n",
    "Let’s ignore the bias and calculate the number of parameters and FLOPs (floating point operations) for the original y:\n",
    "\n",
    " - m * n weights coefficients<br>\n",
    " - m * n FLOPs (for batch size = 1)<br>\n",
    "\n",
    "After TSVD we have:\n",
    "\n",
    " - mk + kn = k*(m+n) weights coefficients<br>\n",
    " - kn + mk = k*(m+n) FLOPs (for batch size = 1)<br>\n",
    "\n",
    "To actually compress the weights after TSVD, we want: m * n > k*(m+n)<br>\n",
    "Let’s rewrite k in terms of m: k = tm<br>\n",
    "m * n > tm*(m+n)<br>\n",
    "n > t*(m+n)<br>\n",
    "n / (m+n) >= t<br>\n",
    "\n",
    "This is the math, but for an actual performance increase, we should strive for m * n >> k*(m+n)\n",
    "\n",
    "In the example notebook: m = 1000; n=2048<br>\n",
    "So when t=2048/(1000+2048) (that is, k=2048/3048*1000=672), we have equilibrium. When 0.672>t (i.e. k is smaller than 672), the sum of the size of the weights of A and U’ is smaller than the size of W.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "import scipy.stats as ss\n",
    "\n",
    "# Relative import of code from distiller, w/o installing the package\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import scipy   \n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "module_path = os.path.abspath(os.path.join('..'))\n",
    "import distiller\n",
    "import distiller.apputils as apputils\n",
    "import distiller.models as models\n",
    "from distiller.apputils import *\n",
    "\n",
    "plt.style.use('seaborn') # pretty matplotlib plots"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def imagenet_load_data(data_dir, batch_size, num_workers, shuffle=True):\n",
    "    \"\"\"Load the ImageNet dataset\"\"\"\n",
    "    test_dir = os.path.join(data_dir, 'val')\n",
    "    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                     std=[0.229, 0.224, 0.225])\n",
    "\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.ImageFolder(test_dir, transforms.Compose([\n",
    "            transforms.Resize(256),\n",
    "            transforms.CenterCrop(224),\n",
    "            transforms.ToTensor(),\n",
    "            normalize,\n",
    "        ])),\n",
    "        batch_size=batch_size, shuffle=shuffle,\n",
    "        num_workers=num_workers, pin_memory=True)\n",
    "\n",
    "    return test_loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 4\n",
    "\n",
    "# Data loader\n",
    "test_loader = imagenet_load_data(\"/datasets/imagenet/\", \n",
    "                                 batch_size=BATCH_SIZE, \n",
    "                                 num_workers=2)\n",
    "    \n",
    "    \n",
    "# Reverse the normalization transformations we performed when we loaded the data \n",
    "# for consumption by our CNN.\n",
    "# See: https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/3\n",
    "invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],\n",
    "                                                     std = [ 1/0.229, 1/0.224, 1/0.225 ]),\n",
    "                                transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],\n",
    "                                                     std = [ 1., 1., 1. ]),\n",
    "                               ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load ResNet 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fc_layer(2048, 1000)\n",
      "W = torch.Size([1000, 2048])\n",
      "U = torch.Size([1000, 400])\n",
      "SV = torch.Size([400, 2048])\n"
     ]
    }
   ],
   "source": [
    "# Load the various models\n",
    "resnet50 = models.create_model(pretrained=True, dataset='imagenet', arch='resnet50', parallel=False)\n",
    "\n",
    "\n",
    "# See Faster-RCNN: https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/compress_net.py\n",
    "# Replaced numpy operations with pytorch operations (so that we can leverage the GPU).\n",
    "def truncated_svd(W, l):\n",
    "    \"\"\"Compress the weight matrix W of an inner product (fully connected) layer\n",
    "    using truncated SVD.\n",
    "    Parameters:\n",
    "    W: N x M weights matrix\n",
    "    l: number of singular values to retain\n",
    "    Returns:\n",
    "    Ul, L: matrices such that W \\approx Ul*L\n",
    "    \"\"\"\n",
    "\n",
    "    U, s, V = torch.svd(W, some=True)\n",
    "\n",
    "    Ul = U[:, :l]\n",
    "    sl = s[:l]\n",
    "    V = V.t()\n",
    "    Vl = V[:l, :]\n",
    "\n",
    "    SV = torch.mm(torch.diag(sl), Vl)\n",
    "    return Ul, SV\n",
    "\n",
    "\n",
    "class TruncatedSVD(nn.Module):\n",
    "    def __init__(self, replaced_gemm, gemm_weights, preserve_ratio):\n",
    "        super().__init__()\n",
    "        self.replaced_gemm = replaced_gemm\n",
    "        print(\"W = {}\".format(gemm_weights.shape))\n",
    "        self.U, self.SV = truncated_svd(gemm_weights.data, int(preserve_ratio * gemm_weights.size(0)))\n",
    "        print(\"U = {}\".format(self.U.shape))\n",
    "        \n",
    "        self.fc_u = nn.Linear(self.U.size(1), self.U.size(0)).cuda()\n",
    "        self.fc_u.weight.data = self.U\n",
    "        \n",
    "        print(\"SV = {}\".format(self.SV.shape))\n",
    "        self.fc_sv = nn.Linear(self.SV.size(1), self.SV.size(0)).cuda()\n",
    "        self.fc_sv.weight.data = self.SV#.t()        \n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc_sv.forward(x)\n",
    "        x = self.fc_u.forward(x)\n",
    "        return x\n",
    "\n",
    "    \n",
    "def replace(model):\n",
    "    fc_weights = model.state_dict()['fc.weight']\n",
    "    fc_layer = model.fc\n",
    "    print(\"fc_layer({}, {})\".format(fc_layer.in_features, fc_layer.out_features))\n",
    "    model.fc = TruncatedSVD(fc_layer, fc_weights, 0.4)\n",
    "\n",
    "    \n",
    "from copy import deepcopy\n",
    "resnet50 = deepcopy(resnet50)\n",
    "replace(resnet50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "progress: 51200 images\n",
      "progress: 102400 images\n",
      "progress: 153600 images\n",
      "progress: 204800 images\n",
      "progress: 256000 images\n",
      "progress: 307200 images\n",
      "progress: 358400 images\n",
      "Top1: 75.70   Top5: 92.76\n",
      "Duration:  168.111492395401\n"
     ]
    }
   ],
   "source": [
    "# Standard loop to test the accuracy of a model.\n",
    "\n",
    "import time\n",
    "import torchnet.meter as tnt\n",
    "t0 = time.time()\n",
    "test_loader = imagenet_load_data(\"/datasets/imagenet\", \n",
    "                                 batch_size=64, \n",
    "                                 num_workers=4,\n",
    "                                 shuffle=False)\n",
    "\n",
    "t1 = time.time()\n",
    "classerr = tnt.ClassErrorMeter(accuracy=True, topk=(1, 5))\n",
    "resnet50.eval()\n",
    "for validation_step, (inputs, target) in enumerate(test_loader):\n",
    "    with torch.no_grad():\n",
    "        inputs, target = inputs.to('cuda'), target.to('cuda')\n",
    "        outputs = resnet50(inputs)\n",
    "        classerr.add(outputs.data, target)\n",
    "        if (validation_step+1) % 100 == 0:\n",
    "            print(\"progress: %d images\" % ((validation_step+1) * 512))\n",
    "        \n",
    "print(\"Top1: %.2f   Top5: %.2f\" % (classerr.value(1), classerr.value(5)))\n",
    "t2 = time.time()\n",
    "print(\"Duration: \", t2-t0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
